{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Learning Rate Scheduling\n",
    ":label:`sec_scheduler`\n",
    "\n",
    "So far we primarily focused on optimization *algorithms* for how to update the weight vectors rather than on the *rate* at which they are being updated. Nonetheless, adjusting the learning rate is often just as important as the actual algorithm. There are a number of aspects to consider:\n",
    "\n",
    "* Most obviously the *magnitude* of the learning rate matters. If it is too large, optimization diverges, if it is too small, it takes too long to train or we end up with a suboptimal result. We saw previously that the condition number of the problem matters (see e.g., :numref:`sec_momentum` for details). Intuitively it is the ratio of the amount of change in the least sensitive direction vs. the most sensitive one.\n",
    "* Secondly, the rate of decay is just as important. If the learning rate remains large we may simply end up bouncing around the minimum and thus not reach optimality. :numref:`sec_minibatch_sgd` discussed this in some detail and we analyzed performance guarantees in :numref:`sec_sgd`. In short, we want the rate to decay, but probably more slowly than $\\mathcal{O}(t^{-\\frac{1}{2}})$ which would be a good choice for convex problems.\n",
    "* Another aspect that is equally important is *initialization*. This pertains both to how the parameters are set initially (review :numref:`sec_numerical_stability` for details) and also how they evolve initially. This goes under the moniker of *warmup*, i.e., how rapidly we start moving towards the solution initially. Large steps in the beginning might not be beneficial, in particular since the initial set of parameters is random. The initial update directions might be quite meaningless, too.\n",
    "* Lastly, there are a number of optimization variants that perform cyclical learning rate adjustment. This is beyond the scope of the current chapter. We recommend the reader to review details in :cite:`Izmailov.Podoprikhin.Garipov.ea.2018`, e.g., how to obtain better solutions by averaging over an entire *path* of parameters.\n",
    "\n",
    "Given the fact that there is a lot of detail needed to manage learning rates, most deep learning frameworks have tools to deal with this automatically. In the current chapter we will review the effects that different schedules have on accuracy and also show how this can be managed efficiently via a *learning rate scheduler*. \n",
    "\n",
    "In DJL we will be referring to these as learning rate trackers.\n",
    "\n",
    "## Toy Problem\n",
    "\n",
    "We begin with a toy problem that is cheap enough to compute easily, yet sufficiently nontrivial to illustrate some of the key aspects. For that we pick a slightly modernized version of LeNet (`relu` instead of `sigmoid` activation, MaxPooling rather than AveragePooling), as applied to Fashion-MNIST. Moreover, we hybridize the network for performance. Since most of the code is standard we just introduce the basics without further detailed discussion. See :numref:`chap_cnn` for a refresher as needed.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/GradDescUtils.java\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/StopWatch.java\n",
    "\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/TrainingChapter11.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.basicdataset.cv.classification.*;\n",
    "import org.apache.commons.lang3.ArrayUtils;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SequentialBlock net = new SequentialBlock();\n",
    "\n",
    "net.add(Conv2d.builder()\n",
    "        .setKernelShape(new Shape(5, 5))\n",
    "        .optPadding(new Shape(2, 2))\n",
    "        .setFilters(1)\n",
    "        .build());\n",
    "net.add(Activation.reluBlock());\n",
    "net.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)));\n",
    "net.add(Conv2d.builder()\n",
    "        .setKernelShape(new Shape(5, 5))\n",
    "        .setFilters(1)\n",
    "        .build());\n",
    "net.add(Blocks.batchFlattenBlock());\n",
    "net.add(Activation.reluBlock());\n",
    "net.add(Linear.builder().setUnits(120).build());\n",
    "net.add(Activation.reluBlock());\n",
    "net.add(Linear.builder().setUnits(84).build());\n",
    "net.add(Activation.reluBlock());\n",
    "net.add(Linear.builder().setUnits(10).build());"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int batchSize = 256;\n",
    "RandomAccessDataset trainDataset = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TRAIN)\n",
    "        .setSampling(batchSize, false)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "RandomAccessDataset testDataset = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TEST)\n",
    "        .setSampling(batchSize, false)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "double[] trainLoss;\n",
    "double[] testAccuracy;\n",
    "double[] epochCount;\n",
    "double[] trainAccuracy;\n",
    "\n",
    "public static void train(RandomAccessDataset trainIter, RandomAccessDataset testIter,\n",
    "                             int numEpochs, Trainer trainer) throws IOException, TranslateException {\n",
    "    epochCount = new double[numEpochs];\n",
    "\n",
    "    for (int i = 0; i < epochCount.length; i++) {\n",
    "        epochCount[i] = (i + 1);\n",
    "    }\n",
    "\n",
    "    double avgTrainTimePerEpoch = 0;\n",
    "    Map<String, double[]> evaluatorMetrics = new HashMap<>();\n",
    "\n",
    "    trainer.setMetrics(new Metrics());\n",
    "\n",
    "    EasyTrain.fit(trainer, numEpochs, trainIter, testIter);\n",
    "\n",
    "    Metrics metrics = trainer.getMetrics();\n",
    "\n",
    "    trainer.getEvaluators().stream()\n",
    "            .forEach(evaluator -> {\n",
    "                evaluatorMetrics.put(\"train_epoch_\" + evaluator.getName(), metrics.getMetric(\"train_epoch_\" + evaluator.getName()).stream()\n",
    "                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n",
    "                evaluatorMetrics.put(\"validate_epoch_\" + evaluator.getName(), metrics.getMetric(\"validate_epoch_\" + evaluator.getName()).stream()\n",
    "                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n",
    "            });\n",
    "\n",
    "    avgTrainTimePerEpoch = metrics.mean(\"epoch\");\n",
    "\n",
    "    trainLoss = evaluatorMetrics.get(\"train_epoch_SoftmaxCrossEntropyLoss\");\n",
    "    trainAccuracy = evaluatorMetrics.get(\"train_epoch_Accuracy\");\n",
    "    testAccuracy = evaluatorMetrics.get(\"validate_epoch_Accuracy\");\n",
    "\n",
    "    System.out.printf(\"loss %.3f,\" , trainLoss[numEpochs-1]);\n",
    "    System.out.printf(\" train acc %.3f,\" , trainAccuracy[numEpochs-1]);\n",
    "    System.out.printf(\" test acc %.3f\\n\" , testAccuracy[numEpochs-1]);\n",
    "    System.out.printf(\"%.1f examples/sec \\n\", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 2
   },
   "source": [
    "Let us have a look at what happens if we invoke this algorithm with default settings, such as a learning rate of $0.3$ and train for $30$ iterations. Note how the training accuracy keeps on increasing while progress in terms of test accuracy stalls beyond a point. The gap between both curves indicates overfitting.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 3,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "float lr = 0.3f;\n",
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 10);\n",
    "\n",
    "Model model = Model.newInstance(\"Modern LeNet\");\n",
    "model.setBlock(net);\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "Tracker lrt = Tracker.fixed(lr);\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "        .optOptimizer(sgd) // Optimizer\n",
    "        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.initialize(new Shape(1, 1, 28, 28));\n",
    "\n",
    "train(trainDataset, testDataset, numEpochs, trainer);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public void plotMetrics() {\n",
    "    String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "    Arrays.fill(lossLabel, 0, trainLoss.length, \"train loss\");\n",
    "    Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, \"train acc\");\n",
    "    Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,\n",
    "                    trainLoss.length + testAccuracy.length + trainAccuracy.length, \"test acc\");\n",
    "\n",
    "    Table data = Table.create(\"Data\").addColumns(\n",
    "        DoubleColumn.create(\"epoch\", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),\n",
    "        DoubleColumn.create(\"metrics\", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),\n",
    "        StringColumn.create(\"lossLabel\", lossLabel)\n",
    "    );\n",
    "\n",
    "    display(LinePlot.create(\"\", data, \"epoch\", \"metrics\", \"lossLabel\"));\n",
    "}\n",
    "\n",
    "plotMetrics();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "## Trackers\n",
    "\n",
    "One way of adjusting the learning rate is to set it explicitly at each step. We could adjust it downward after every epoch (or even after every minibatch), e.g., in a dynamic manner in response to how optimization is progressing. \n",
    "\n",
    "We, however, can't directly change the learning rate with the trainer after it has already been created. What we can do instead is create a tracker to do this for us.\n",
    "\n",
    "When invoked with the number of updates it returns the appropriate value of the learning rate. Let us define a simple one that sets the learning rate to $\\eta = \\eta_0 (t + 1)^{-\\frac{1}{2}}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 7,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public class SquareRootTracker {\n",
    "    float lr;\n",
    "    public SquareRootTracker() {\n",
    "        this(0.1f);\n",
    "    }\n",
    "    public SquareRootTracker(float learningRate) {\n",
    "        this.lr = learningRate;\n",
    "    }\n",
    "    public float getNewLearningRate(int numUpdate) {\n",
    "        return lr * (float) Math.pow(numUpdate + 1, -0.5);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: This is not a drop in replacement for a standard Learning Rate Tracker (LRT). \n",
    "This is just a simple example to give a better understanding of how they work."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "Let us plot its behavior over a range of values.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public Figure plotLearningRate(int[] epochs, float[] learningRates) {\n",
    "    \n",
    "    String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "    Arrays.fill(lossLabel, 0, trainLoss.length, \"train loss\");\n",
    "    Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, \"train acc\");\n",
    "    Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,\n",
    "                    trainLoss.length + testAccuracy.length + trainAccuracy.length, \"test acc\");\n",
    "\n",
    "    Table data = Table.create(\"Data\").addColumns(\n",
    "                IntColumn.create(\"epoch\", epochs),\n",
    "                DoubleColumn.create(\"learning rate\", learningRates)\n",
    "    );\n",
    "\n",
    "    return LinePlot.create(\"Learning Rate vs. Epoch\", data, \"epoch\", \"learning rate\");\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 9,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "SquareRootTracker tracker = new SquareRootTracker();\n",
    "\n",
    "int[] epochs = new int[numEpochs];\n",
    "float[] learningRates = new float[numEpochs];\n",
    "for (int i = 0; i < numEpochs; i++) {\n",
    "    epochs[i] = i;\n",
    "    learningRates[i] = tracker.getNewLearningRate(i);\n",
    "}\n",
    "\n",
    "plotLearningRate(epochs, learningRates);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "Now let us see how this plays out for training on Fashion-MNIST. We can't actually do it directly, but we can see how the curve would look theoretically."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
    "This looks like it works quite a bit better than before. Two things stand out: the curve was rather more smooth than previously. Secondly, there was less overfitting. Unfortunately it is not a well-resolved question as to why certain strategies lead to less overfitting in *theory*. There is some argument that a smaller stepsize will lead to parameters that are closer to zero and thus simpler. However, this does not explain the phenomenon entirely since we do not really stop early but simply reduce the learning rate gently.\n",
    "\n",
    "## Policies\n",
    "\n",
    "While we cannot possibly cover the entire variety of learning rate trackers, we attempt to give a brief overview of popular policies below. Common choices are polynomial decay and piecewise constant schedules. Beyond that, cosine learning rate schedules have been found to work well empirically on some problems. Lastly, on some problems it is beneficial to warm up the optimizer prior to using large learning rates.\n",
    "\n",
    "### Factor Tracker\n",
    "\n",
    "One alternative to a polynomial decay would be a multiplicative one, that is $\\eta_{t+1} \\leftarrow \\eta_t \\cdot \\alpha$ for $\\alpha \\in (0, 1)$. To prevent the learning rate from decaying beyond a reasonable lower bound the update equation is often modified to $\\eta_{t+1} \\leftarrow \\mathop{\\mathrm{max}}(\\eta_{\\mathrm{min}}, \\eta_t \\cdot \\alpha)$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 13,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public class DemoFactorTracker {\n",
    "    float baseLr;\n",
    "    float stopFactorLr;\n",
    "    float factor;\n",
    "    public DemoFactorTracker(float factor, float stopFactorLr, float baseLr) {\n",
    "        this.factor = factor;\n",
    "        this.stopFactorLr = stopFactorLr;\n",
    "        this.baseLr = baseLr;\n",
    "    }\n",
    "    public DemoFactorTracker() {\n",
    "        this(1f, (float) 1e-7, 0.1f);\n",
    "    }\n",
    "    public float getNewLearningRate(int numUpdate) {\n",
    "        return lr * (float) Math.pow(numUpdate + 1, -0.5);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DemoFactorTracker tracker = new DemoFactorTracker(0.9f, (float) 1e-2, 2);\n",
    "\n",
    "numEpochs = 50;\n",
    "int[] epochs = new int[numEpochs];\n",
    "float[] learningRates = new float[numEpochs];\n",
    "for (int i = 0; i < numEpochs; i++) {\n",
    "    epochs[i] = i;\n",
    "    learningRates[i] = tracker.getNewLearningRate(i);\n",
    "}\n",
    "\n",
    "plotLearningRate(epochs, learningRates);"
   ]
  },
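  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the lower bound in place, the multiplicative schedule stops changing once $\\eta_0 \\alpha^t \\leq \\eta_{\\mathrm{min}}$, i.e., for $t \\geq \\log(\\eta_{\\mathrm{min}} / \\eta_0) / \\log \\alpha$. The following is a minimal, self-contained sketch in plain Java (no DJL involved; the class and method names are hypothetical and chosen only for illustration). It evaluates the closed form $\\eta_t = \\mathop{\\mathrm{max}}(\\eta_{\\mathrm{min}}, \\eta_0 \\alpha^t)$ and computes that step count for $\\eta_0 = 2$, $\\alpha = 0.9$, and $\\eta_{\\mathrm{min}} = 0.01$, the values passed to `DemoFactorTracker` above.\n",
    "\n",
    "```java\n",
    "public class FactorDecaySketch {\n",
    "    // Closed form of the multiplicative schedule with a lower bound:\n",
    "    // eta_t = max(etaMin, eta0 * alpha^t).\n",
    "    public static double rate(double eta0, double alpha, double etaMin, int t) {\n",
    "        return Math.max(etaMin, eta0 * Math.pow(alpha, t));\n",
    "    }\n",
    "\n",
    "    // Smallest t at which eta0 * alpha^t has dropped to etaMin or below.\n",
    "    // Assumes 0 < alpha < 1 and 0 < etaMin <= eta0.\n",
    "    public static int stepsUntilFloor(double eta0, double alpha, double etaMin) {\n",
    "        return (int) Math.ceil(Math.log(etaMin / eta0) / Math.log(alpha));\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        double eta0 = 2.0, alpha = 0.9, etaMin = 0.01;\n",
    "        int t = stepsUntilFloor(eta0, alpha, etaMin);\n",
    "        System.out.println(\"floor reached after \" + t + \" updates\");\n",
    "        System.out.println(\"rate one update earlier: \" + rate(eta0, alpha, etaMin, t - 1));\n",
    "        System.out.println(\"rate at the floor: \" + rate(eta0, alpha, etaMin, t));\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "For these values the bound is reached after 51 updates, which indicates how long the decay phase effectively lasts before the rate becomes constant."
   ]
  },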
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 14
   },
   "source": [
    "This can also be accomplished by a built-in scheduler in DJL via the `LearningRateTracker.factorTracker()` builder. It takes a few more parameters, such as warmup period, warmup mode (linear or constant), the maximum number of desired updates, etc.; Going forward we will use the built-in schedulers as appropriate and only explain their functionality here. \n",
    "\n",
    "### Multi Factor Scheduler\n",
    "\n",
    "A common strategy for training deep networks is to keep the learning rate piecewise constant and to decrease it by a given amount every so often. That is, given a set of times when to decrease the rate, such as $s = \\{5, 10, 20\\}$ decrease $\\eta_{t+1} \\leftarrow \\eta_t \\cdot \\alpha$ whenever $t \\in s$. Assuming that the values are halved at each step we can implement this as follows.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 15,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "MultiFactorTracker tracker = Tracker.multiFactor()\n",
    "        .setSteps(new int[]{5, 30})\n",
    "        .optFactor(0.5f)\n",
    "        .setBaseValue(0.5f)\n",
    "        .build();\n",
    "\n",
    "numEpochs = 10;\n",
    "int[] epochs = new int[numEpochs];\n",
    "float[] learningRates = new float[numEpochs];\n",
    "for (int i = 0; i < numEpochs; i++) {\n",
    "    epochs[i] = i;\n",
    "    learningRates[i] = tracker.getNewValue(i);\n",
    "}\n",
    "\n",
    "plotLearningRate(epochs, learningRates);"
   ]
  },
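  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the update rule $\\eta_{t+1} \\leftarrow \\eta_t \\cdot \\alpha$ for $t \\in s$ concrete, here is a minimal sketch in plain Java. It does not use DJL's `MultiFactorTracker`, and the class and method names are hypothetical. One assumption to note: the sketch applies the decrease as soon as $t$ equals a milestone; whether the drop happens at the milestone or one update later is a convention that a given library may choose differently.\n",
    "\n",
    "```java\n",
    "public class PiecewiseConstantSketch {\n",
    "    // Rate after t updates: the base rate multiplied by `factor` once for\n",
    "    // every milestone that has been reached (milestone <= t).\n",
    "    public static double rate(double base, double factor, int[] milestones, int t) {\n",
    "        double eta = base;\n",
    "        for (int m : milestones) {\n",
    "            if (t >= m) {\n",
    "                eta *= factor;\n",
    "            }\n",
    "        }\n",
    "        return eta;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        int[] milestones = {5, 10, 20};\n",
    "        for (int t = 0; t <= 20; t += 5) {\n",
    "            System.out.println(\"t=\" + t + \" rate=\" + rate(0.5, 0.5, milestones, t));\n",
    "        }\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "With a base rate of $0.5$ and $\\alpha = 0.5$ the rate is $0.5$ before the first milestone, then $0.25$, $0.125$, and finally $0.0625$ from $t = 20$ onwards."
   ]
  },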
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "The intuition behind this piecewise constant learning rate schedule is that one lets optimization proceed until a stationary point has been reached in terms of the distribution of weight vectors. Then (and only then) do we decrease the rate such as to obtain a higher quality proxy to a good local minimum. The example below shows how this can produce ever slightly better solutions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 17,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 10);\n",
    "\n",
    "Model model = Model.newInstance(\"Modern LeNet\");\n",
    "model.setBlock(net);\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(tracker).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "        .optOptimizer(sgd) // Optimizer\n",
    "        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.initialize(new Shape(1, 1, 28, 28));\n",
    "\n",
    "train(trainDataset, testDataset, numEpochs, trainer);\n",
    "plotMetrics();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "Arrays.fill(lossLabel, 0, trainLoss.length, \"train loss\");\n",
    "Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, \"train acc\");\n",
    "Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,\n",
    "                trainLoss.length + testAccuracy.length + trainAccuracy.length, \"test acc\");\n",
    "\n",
    "Table data = Table.create(\"Data\").addColumns(\n",
    "            DoubleColumn.create(\"epoch\", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),\n",
    "            DoubleColumn.create(\"metrics\", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),\n",
    "            StringColumn.create(\"lossLabel\", lossLabel)\n",
    ");\n",
    "\n",
    "LinePlot.create(\"\", data, \"epoch\", \"metrics\", \"lossLabel\");"
   ]
  },
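  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The milestones above are fixed in advance. The same intuition, namely decreasing the rate only once progress has stalled, can also be applied adaptively by watching the loss. The following is a minimal sketch of such a rule in plain Java. It is not a DJL tracker; the class name, the `patience` parameter, and the per-epoch `update` method are all hypothetical and only serve to illustrate the idea.\n",
    "\n",
    "```java\n",
    "public class PlateauDecaySketch {\n",
    "    private double rate;\n",
    "    private final double factor;\n",
    "    private final int patience;\n",
    "    private double bestLoss = Double.POSITIVE_INFINITY;\n",
    "    private int epochsWithoutImprovement = 0;\n",
    "\n",
    "    public PlateauDecaySketch(double initialRate, double factor, int patience) {\n",
    "        this.rate = initialRate;\n",
    "        this.factor = factor;\n",
    "        this.patience = patience;\n",
    "    }\n",
    "\n",
    "    // Call once per epoch with the current loss; returns the rate to use next.\n",
    "    public double update(double loss) {\n",
    "        if (loss < bestLoss) {\n",
    "            bestLoss = loss;\n",
    "            epochsWithoutImprovement = 0;\n",
    "        } else if (++epochsWithoutImprovement >= patience) {\n",
    "            rate *= factor;               // plateau detected: shrink the rate\n",
    "            epochsWithoutImprovement = 0; // and start counting afresh\n",
    "        }\n",
    "        return rate;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        PlateauDecaySketch schedule = new PlateauDecaySketch(0.5, 0.5, 2);\n",
    "        double[] losses = {1.0, 0.8, 0.7, 0.7, 0.7, 0.6, 0.6, 0.6};\n",
    "        for (double loss : losses) {\n",
    "            System.out.println(\"loss=\" + loss + \" rate=\" + schedule.update(loss));\n",
    "        }\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "In this made-up loss sequence the rate is halved twice, each time after two consecutive epochs without a new best loss."
   ]
  },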
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 18
   },
   "source": [
    "### Cosine Tracker\n",
    "\n",
    "A rather perplexing heuristic was proposed by :cite:`Loshchilov.Hutter.2016`. It relies on the observation that we might not want to decrease the learning rate too drastically in the beginning and moreover, that we might want to \"refine\" the solution in the end using a very small learning rate. This results in a cosine-like tracker with the following functional form for learning rates in the range $t \\in [0, T]$.\n",
    "\n",
    "$$\\eta_t = \\eta_T + \\frac{\\eta_0 - \\eta_T}{2} \\left(1 + \\cos(\\pi t/T)\\right)$$\n",
    "\n",
    "Here $\\eta_0$ is the initial learning rate, $\\eta_T$ is the target rate at time $T$. Furthermore, for $t > T$ we simply pin the value to $\\eta_T$ without increasing it again. In the following example, we set the max update step $T = 20$.\n"
   ]
  },
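  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the formula, consider its values at the start, the midpoint, and the end of the schedule for $\\eta_0 = 0.5$, $\\eta_T = 0.01$, and $T = 20$ (the subscript $t=\\cdot$ is used here only to avoid confusion with the parameters $\\eta_0$ and $\\eta_T$):\n",
    "\n",
    "$$\\begin{aligned}\n",
    "\\eta_{t=0} &= \\eta_T + \\frac{\\eta_0 - \\eta_T}{2} \\left(1 + \\cos 0\\right) = \\eta_0 = 0.5, \\\\\n",
    "\\eta_{t=T/2} &= \\eta_T + \\frac{\\eta_0 - \\eta_T}{2} \\left(1 + \\cos(\\pi/2)\\right) = \\frac{\\eta_0 + \\eta_T}{2} = 0.255, \\\\\n",
    "\\eta_{t=T} &= \\eta_T + \\frac{\\eta_0 - \\eta_T}{2} \\left(1 + \\cos \\pi\\right) = \\eta_T = 0.01.\n",
    "\\end{aligned}$$\n",
    "\n",
    "Since the cosine is flat near $0$ and $\\pi$, the rate changes slowly at both ends and fastest around the midpoint."
   ]
  },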
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class DemoCosineTracker {\n",
    "    float baseLr;\n",
    "    float finalLr;\n",
    "    int maxUpdate;\n",
    "    public DemoCosineTracker() {\n",
    "        this(0.5f, 0.01f, 20);\n",
    "    }\n",
    "    public DemoCosineTracker(float baseLr, float finalLr, int maxUpdate) {\n",
    "        this.baseLr = baseLr;\n",
    "        this.finalLr = finalLr;\n",
    "        this.maxUpdate = maxUpdate;\n",
    "    }\n",
    "    public float getNewLearningRate(int numUpdate) {\n",
    "        if (numUpdate > maxUpdate) {\n",
    "            return finalLr;\n",
    "        }\n",
    "        // Scale the curve to smoothly transition\n",
    "        float step = (baseLr - finalLr) / 2 * (1 + (float) Math.cos(Math.PI * numUpdate / maxUpdate));\n",
    "        return finalLr + step;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 19,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "DemoCosineTracker tracker = new DemoCosineTracker(0.5f, 0.01f, 20);\n",
    "\n",
    "int[] epochs = new int[numEpochs];\n",
    "float[] learningRates = new float[numEpochs];\n",
    "for (int i = 0; i < numEpochs; i++) {\n",
    "    epochs[i] = i;\n",
    "    learningRates[i] = tracker.getNewLearningRate(i);\n",
    "}\n",
    "\n",
    "plotLearningRate(epochs, learningRates);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "In the context of computer vision this schedule *can* lead to improved results. Note, though, that such improvements are not guaranteed (as can be seen below).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 21,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "CosineTracker cosineTracker = Tracker.cosine()\n",
    "                            .setBaseValue(0.5f)\n",
    "                            .optFinalValue(0.01f)\n",
    "                            .setMaxUpdates(20)\n",
    "                            .build();\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(cosineTracker).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "        .optOptimizer(sgd) // Optimizer\n",
    "        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.initialize(new Shape(1, 1, 28, 28));\n",
    "\n",
    "train(trainDataset, testDataset, numEpochs, trainer);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 22
   },
   "source": [
    "### Warmup\n",
    "\n",
    "In some cases initializing the parameters is not sufficient to guarantee a good solution. This particularly a problem for some advanced network designs that may lead to unstable optimization problems. We could address this by choosing a sufficiently small learning rate to prevent divergence in the beginning. Unfortunately this means that progress is slow. Conversely, a large learning rate initially leads to divergence.\n",
    "\n",
    "A rather simple fix for this dilemma is to use a warmup period during which the learning rate *increases* to its initial maximum and to cool down the rate until the end of the optimization process. For simplicity one typically uses a linear increase for this purpose. This leads to a schedule of the form indicated below.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class CosineWarmupTracker {\n",
    "    float baseLr;\n",
    "    float finalLr;\n",
    "    int maxUpdate;\n",
    "    int warmUpSteps;\n",
    "    float warmUpBeginValue;\n",
    "    float warmUpFinalValue;\n",
    "    \n",
    "    public CosineWarmupTracker() {\n",
    "        this(0.5f, 0.01f, 20, 5);\n",
    "    }\n",
    "    \n",
    "    public CosineWarmupTracker(float baseLr, float finalLr, int maxUpdate, int warmUpSteps) {\n",
    "        this.baseLr = baseLr;\n",
    "        this.finalLr = finalLr;\n",
    "        this.maxUpdate = maxUpdate;\n",
    "        this.warmUpSteps = 5;\n",
    "        this.warmUpBeginValue = 0f;\n",
    "    }\n",
    "    \n",
    "    public float getNewLearningRate(int numUpdate) {\n",
    "        if (numUpdate <= warmUpSteps) {\n",
    "            return getWarmUpValue(numUpdate);\n",
    "        }\n",
    "        if (numUpdate > maxUpdate) {\n",
    "            return finalLr;\n",
    "        }\n",
    "        // Scale the cosine curve to fit smoothly with the warmup steps\n",
    "        float step = (baseLr - finalLr) / 2 * (1 + \n",
    "            (float) Math.cos(Math.PI * (numUpdate - warmUpSteps) / (maxUpdate - warmUpSteps)));\n",
    "        return finalLr + step;\n",
    "    }\n",
    "    \n",
    "    public float getWarmUpValue(int numUpdate) {\n",
    "        // Linear warmup\n",
    "        return warmUpBeginValue + (baseLr - warmUpBeginValue) * numUpdate / warmUpSteps;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 23,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "CosineWarmupTracker tracker = new CosineWarmupTracker(0.5f, 0.01f, 20, 5);\n",
    "\n",
    "int[] epochs = new int[numEpochs];\n",
    "float[] learningRates = new float[numEpochs];\n",
    "for (int i = 0; i < numEpochs; i++) {\n",
    "    epochs[i] = i;\n",
    "    learningRates[i] = tracker.getNewLearningRate(i);\n",
    "}\n",
    "\n",
    "plotLearningRate(epochs, learningRates);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 24
   },
   "source": [
    "Note that the network converges better initially (in particular observe the performance during the first 5 epochs).\n",
    "\n",
    "Additionally, we still use a total of 20 max updates, but the 1st\n",
    "5 are dedicated to the warmup steps. The cosine curve will then be\n",
    "squeezed into the 15 steps relative to the earlier 20 steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 25,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "CosineTracker cosineTracker = Tracker.cosine()\n",
    "        .setBaseValue(0.5f)\n",
    "        .optFinalValue(0.01f)\n",
    "        .setMaxUpdates(15)\n",
    "        .build();\n",
    "\n",
    "WarmUpTracker warmupCosine = Tracker.warmUp()\n",
    "        .optWarmUpSteps(5)\n",
    "        .setMainTracker(cosineTracker)\n",
    "        .build();\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(warmupCosine).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "        .optOptimizer(sgd) // Optimizer\n",
    "        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.initialize(new Shape(1, 1, 28, 28));\n",
    "\n",
    "train(trainDataset, testDataset, numEpochs, trainer);\n",
    "plotMetrics();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 26
   },
   "source": [
    "Warmup can be applied to any scheduler (not just cosine). For a more detailed discussion of learning rate schedules and many more experiments see also :cite:`Gotmare.Keskar.Xiong.ea.2018`. In particular they find that a warmup phase limits the amount of divergence of parameters in very deep networks. This makes intuitively sense since we would expect significant divergence due to random initialization in those parts of the network that take the most time to make progress in the beginning.\n",
    "\n",
    "## Summary\n",
    "\n",
    "* Decreasing the learning rate during training can lead to improved accuracy and (most perplexingly) reduced overfitting of the model.\n",
    "* A piecewise decrease of the learning rate whenever progress has plateaued is effective in practice. Essentially this ensures that we converge efficiently to a suitable solution and only then reduce the inherent variance of the parameters by reducing the learning rate.\n",
    "* Cosine schedulers are popular for some computer vision problems.\n",
    "* A warmup period before optimization can prevent divergence.\n",
    "* Optimization serves multiple purposes in deep learning. Besides minimizing the training objective, different choices of optimization algorithms and learning rate scheduling can lead to rather different amounts of generalization and overfitting on the test set (for the same amount of training error).\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Experiment with the optimization behavior for a given fixed learning rate. What is the best model you can obtain this way?\n",
    "1. How does convergence change if you change the exponent of the decrease in the learning rate?\n",
    "1. Apply the cosine scheduler to large computer vision problems, e.g., training ImageNet. How does it affect performance relative to other schedulers?\n",
    "1. How long should warmup last?\n",
    "1. Can you connect optimization and sampling? Start by using results from :cite:`Welling.Teh.2011` on Stochastic Gradient Langevin Dynamics.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
